{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读写文件\n",
    "\n",
    "到目前为止，我们讨论了如何处理数据，以及如何构建、训练和测试深度学习模型。然而，有时我们对所学的模型足够满意，我们希望保存训练的模型以备将来在各种环境中使用（可能部署进行预测）。此外，当运行一个耗时较长的训练过程时，最佳实践是定期保存中间结果（检查点），以确保在服务器电源被不小心断掉时不会损失几天的计算结果。因此，现在是时候学习如何加载和存储权重向量和整个模型。本节将讨论这些问题。\n",
    "\n",
    "## 加载和保存 NDArray\n",
    "\n",
    "对于单个 `NDArray`，我们可以直接调用 `encode()` 和 `decode()` 函数分别读写它们。为我们可以使用这两个函数结合 `FileInputStream`, `FileOutputStream` 来实现对文件的读写。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "NDArray x = manager.arange(4);\n",
    "try (FileOutputStream fos = new FileOutputStream(\"x-file\")) {\n",
    "    fos.write(x.encode());\n",
    "}\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们现在可以将存储在文件中的数据读回内存。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray x2;\n",
    "try (FileInputStream fis = new FileInputStream(\"x-file\")) {\n",
    "    // We use the `Utils` method `toByteArray()` to read \n",
    "    // from a `FileInputStream` and return it as a `byte[]`.\n",
    "    x2 = NDArray.decode(manager, Utils.toByteArray(fis));\n",
    "}\n",
    "x2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们也可以存储一个 `NDList`，然后把它们读回内存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList list = new NDList(x, x2);\n",
    "try (FileOutputStream fos = new FileOutputStream(\"x-file\")) {\n",
    "    fos.write(list.encode());\n",
    "}\n",
    "try (FileInputStream fis = new FileInputStream(\"x-file\")) {\n",
    "    list = NDList.decode(manager, Utils.toByteArray(fis));\n",
    "}\n",
    "list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载和保存模型参数\n",
    "\n",
    "保存单个权重向量（或其他`NDArray`）确实是有用的，但是如果我们想保存整个模型，并在以后加载它们。单独保存每个向量则会变得很麻烦。毕竟，我们可能有数百个参数散布在各处。因此，深度学习框架提供了内置函数来保存和加载整个网络。需要注意的一个重要细节是，这将保存模型的参数而不是保存整个模型。例如，如果我们有一个3层多层感知机，我们需要单独指定结构。因为模型本身可以包含任意代码，所以模型本身难以序列化。因此，为了恢复模型，我们需要用代码生成结构，然后从磁盘加载参数。让我们从熟悉的多层感知机开始尝试一下。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public SequentialBlock createMLP() {\n",
    "    SequentialBlock mlp = new SequentialBlock();\n",
    "    mlp.add(Linear.builder().setUnits(256).build());\n",
    "    mlp.add(Activation.reluBlock());\n",
    "    mlp.add(Linear.builder().setUnits(10).build());\n",
    "    return mlp;\n",
    "}\n",
    "\n",
    "SequentialBlock original = createMLP();\n",
    "\n",
    "NDArray x = manager.randomUniform(0, 1, new Shape(2, 5));\n",
    "\n",
    "original.initialize(manager, DataType.FLOAT32, x.getShape());\n",
    "\n",
    "ParameterStore ps = new ParameterStore(manager, false);\n",
    "NDArray y = original.forward(ps, new NDList(x), false).singletonOrThrow();\n",
    "\n",
    "y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们将模型的参数存储为一个叫做“mlp.params”的文件："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Save file\n",
    "File mlpParamFile = new File(\"mlp.param\");\n",
    "DataOutputStream os = new DataOutputStream(Files.newOutputStream(mlpParamFile.toPath()));\n",
    "original.saveParameters(os);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了恢复模型，我们实例化了原始多层感知机模型。我们没有随机初始化模型参数，而是直接读取文件中存储的参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Create duplicate of network architecture\n",
    "SequentialBlock clone = createMLP();\n",
    "// Load Parameters\n",
    "clone.loadParameters(manager, new DataInputStream(Files.newInputStream(mlpParamFile.toPath())));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于两个实例具有相同的模型参数，在输入相同的`X`时，两个实例的计算结果应该相同。让我们来验证一下。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Original model's parameters\n",
    "PairList<String, Parameter> originalParams = original.getParameters();\n",
    "// Loaded model's parameters\n",
    "PairList<String, Parameter> loadedParams = clone.getParameters();\n",
    "\n",
    "for (int i = 0; i < originalParams.size(); i++) {\n",
    "    if (originalParams.valueAt(i).getArray().equals(loadedParams.valueAt(i).getArray())) {\n",
    "        System.out.printf(\"True \");\n",
    "    } else {\n",
    "        System.out.printf(\"False \");\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray yClone = clone.forward(ps, new NDList(x), false).singletonOrThrow();\n",
    "\n",
    "y.equals(yClone);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* `encode()` 和 `decode()` 函数可用于 `NDArray` 和 `NDList` 对象的文件读写。\n",
    "* 我们可以保存和加载模型文件。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 即使不需要将经过训练的模型部署到不同的设备上，存储模型参数还有什么实际的好处？\n",
    "1. 假设我们只想复用网络的一部分，以将其合并到不同的网络结构中。比如说，如果你想在一个新的网络中使用之前网络的前两层，你该怎么做？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
